{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading Caffe-model weights for SSD"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import cv2 as cv\n",
    "from motrackers.detectors.caffe import Caffe_SSDMobileNet\n",
    "from motrackers import CentroidTracker, CentroidKF_Tracker, SORT, IOUTracker\n",
    "from motrackers.utils import draw_tracks\n",
    "import ipywidgets as widgets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "abd6e0e1d0aa46138a7fc205c466f217",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Select(description='MOTracker:', options=('CentroidTracker', 'CentroidKF_Tracker', 'SORT', 'IOUTracker'), valu…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "chosen_tracker = widgets.Select(\n",
    "    options=[\"CentroidTracker\", \"CentroidKF_Tracker\", \"SORT\", \"IOUTracker\"],\n",
    "    value='CentroidTracker',\n",
    "    rows=5,\n",
    "    description='MOTracker:',\n",
    "    disabled=False\n",
    ")\n",
    "chosen_tracker"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "if chosen_tracker.value == 'CentroidTracker':\n",
    "    tracker = CentroidTracker(max_lost=0, tracker_output_format='mot_challenge')\n",
    "elif chosen_tracker.value == 'CentroidKF_Tracker':\n",
    "    tracker = CentroidKF_Tracker(max_lost=0, tracker_output_format='mot_challenge')\n",
    "elif chosen_tracker.value == 'SORT':\n",
    "    tracker = SORT(max_lost=3, tracker_output_format='mot_challenge', iou_threshold=0.3)\n",
    "elif chosen_tracker.value == 'IOUTracker':\n",
    "    tracker = IOUTracker(max_lost=2, iou_threshold=0.5, min_detection_confidence=0.4, max_detection_confidence=0.7,\n",
    "                         tracker_output_format='mot_challenge')\n",
    "else:\n",
    "    print(\"Please choose one tracker from the above list.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "VIDEO_FILE = \"./../video_data/cars.mp4\"\n",
    "WEIGHTS_PATH = \"./../pretrained_models/caffemodel_weights/MobileNetSSD_deploy.caffemodel\"\n",
    "CONFIG_FILE_PATH = \"./../pretrained_models/caffemodel_weights/MobileNetSSD_deploy.prototxt\"\n",
    "LABELS_PATH=\"./../pretrained_models/caffemodel_weights/ssd_mobilenet_caffe_names.json\"\n",
    "\n",
    "CONFIDENCE_THRESHOLD = 0.5\n",
    "NMS_THRESHOLD = 0.2\n",
    "DRAW_BOUNDING_BOXES = True\n",
    "USE_GPU=False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Caffe_SSDMobileNet(\n",
    "    weights_path=WEIGHTS_PATH,\n",
    "    configfile_path=CONFIG_FILE_PATH,\n",
    "    labels_path=LABELS_PATH,\n",
    "    confidence_threshold=CONFIDENCE_THRESHOLD,\n",
    "    nms_threshold=NMS_THRESHOLD, \n",
    "    draw_bboxes=DRAW_BOUNDING_BOXES,\n",
    "    use_gpu=USE_GPU\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "def main(video_path, model, tracker):\n",
    "\n",
    "    cap = cv.VideoCapture(video_path)\n",
    "    while True:\n",
    "        ok, image = cap.read()\n",
    "\n",
    "        if not ok:\n",
    "            print(\"Cannot read the video feed.\")\n",
    "            break\n",
    "\n",
    "        image = cv.resize(image, (700, 500))\n",
    "\n",
    "        bboxes, confidences, class_ids = model.detect(image)\n",
    "        \n",
    "        tracks = tracker.update(bboxes, confidences, class_ids)\n",
    "        \n",
    "        updated_image = model.draw_bboxes(image.copy(), bboxes, confidences, class_ids)\n",
    "\n",
    "        updated_image = draw_tracks(updated_image, tracks)\n",
    "\n",
    "        cv.imshow(\"image\", updated_image)\n",
    "        if cv.waitKey(1) & 0xFF == ord('q'):\n",
    "            break\n",
    "\n",
    "    cap.release()\n",
    "    cv.destroyAllWindows()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "main(VIDEO_FILE, model, tracker)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "work_env",
   "language": "python",
   "name": "work_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
